// Flat offset of entry `k` along the `channels` axis, for the `n` and `s`
// values of the enclosing scope: s + spatial_dim * (k + channels * n).
//
// The macro reads `s`, `n`, `channels` and `spatial_dim` from the scope in
// which it is expanded, so all four names must be visible there.
//
// Every operand is widened to 64 bits before the arithmetic so the product
// cannot wrap around in 32-bit int, and `k` is wrapped in a cast so that an
// argument containing low-precedence operators (e.g. `a ? b : c`) is
// evaluated as a unit.
#define ACCURACY_IDX(k)                                                   \
  (static_cast<long long>(s) +                                            \
   static_cast<long long>(spatial_dim) *                                  \
       (static_cast<long long>(k) +                                       \
        static_cast<long long>(channels) * static_cast<long long>(n)))
/**
 * Writes a per-position top-1 hit flag for a batch of class scores.
 *
 * Each thread handles one (n, s) pair: n is derived from the x launch
 * dimension and s from the y launch dimension. Threads whose pair falls
 * outside [0, num) x [0, spatial_dim) return without touching memory, so the
 * grid may be larger than the problem.
 *
 * For its pair, the thread scans the `channels` scores located at
 *   pred[s + spatial_dim * (c + channels * n)],  c = 0 .. channels - 1,
 * and keeps the index of the largest one. The comparison is strict, so the
 * lowest index wins ties. It then stores 1 in acc_buf[s + spatial_dim * n]
 * if that index equals the label read from labels[s + spatial_dim * n]
 * (converted to int), and 0 otherwise.
 *
 * @param pred         scores, indexed as shown above (read only)
 * @param labels       class labels stored as T, one per (n, s) pair (read only)
 * @param acc_buf      output flags, one per (n, s) pair
 * @param num          extent of the n axis
 * @param channels     number of scores per (n, s) pair; when it is <= 0 the
 *                     flag is set to 0 and `pred` / `labels` are not read
 * @param spatial_dim  extent of the s axis
 *
 * NOTE(review): threadIdx.x is mapped to n, whose stride in `pred` is
 * spatial_dim * channels, so neighbouring x-threads read far-apart addresses.
 * Swapping the axes would need a matching change in the launch code, which is
 * not visible here.
 */
template <typename T>
__device__ void accuracy_forward(T *pred, T *labels, T *acc_buf, int num, int channels, int spatial_dim) {
  // Form the global indices in 64 bits so a very large grid cannot wrap them.
  const long long n = threadIdx.x + static_cast<long long>(blockIdx.x) * blockDim.x;
  const long long s = threadIdx.y + static_cast<long long>(blockIdx.y) * blockDim.y;
  if (n >= num || s >= spatial_dim)
    return;

  // All element offsets are 64-bit: spatial_dim * channels * num may exceed
  // the range of a 32-bit int.
  const long long plane = spatial_dim;
  const long long idx = s + plane * n;

  // With no scores there is nothing to compare and pred must not be read.
  if (channels <= 0) {
    acc_buf[idx] = T(0);
    return;
  }

  const int label = static_cast<int>(labels[idx]);

  // First score of this (n, s) pair; successive classes are `plane` apart.
  const T *scores = pred + (s + plane * channels * n);

  T maxval = scores[0];
  int maxidx = 0;
  for (int c = 1; c < channels; ++c) {
    const T candidate = scores[c * plane];  // load each score only once
    if (candidate > maxval) {
      maxval = candidate;
      maxidx = c;
    }
  }

  acc_buf[idx] = (maxidx == label) ? T(1) : T(0);
}

// Kernel entry points with C linkage (unmangled symbol names), one per
// element type. Each one simply forwards its arguments to the templated
// accuracy_forward device function.
extern "C" {
  // Single-precision variant.
  __global__ void accuracy_forward_float(float *pred, float *label, float *acc_buf,
                                         int num, int channels, int spatial_dim) {
    accuracy_forward<float>(pred, label, acc_buf, num, channels, spatial_dim);
  }

  // Double-precision variant.
  __global__ void accuracy_forward_double(double *pred, double *label, double *acc_buf,
                                          int num, int channels, int spatial_dim) {
    accuracy_forward<double>(pred, label, acc_buf, num, channels, spatial_dim);
  }
}

// vim: ft=cuda
